import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from idspn import iDSPN, MSEObjective, clevr_project
from dspn import DSPN
from slot import SlotAttention


def create_pairs(a, b=None):
    """Broadcast two batched sets against each other to form all ordered pairs.

    Args:
        a: 3-D tensor; dim 1 indexes set elements (e.g. (batch, n_a, c_a)).
        b: optional 3-D tensor with the same leading batch dim as ``a``.
            Defaults to ``a`` itself, giving all pairs within one set.

    Returns:
        ``(left, right)``, both of shape (batch, n_a, n_b, ·) where
        ``left[:, i, j] == a[:, i]`` and ``right[:, i, j] == b[:, j]``.
        These are ``expand`` views, so no data is copied.
    """
    other = a if b is None else b
    n_left, n_right = a.size(1), other.size(1)
    # Insert a "partner" axis into each operand and broadcast along it.
    left = a.unsqueeze(2).expand(-1, -1, n_right, -1)
    right = other.unsqueeze(1).expand(-1, n_left, -1, -1)
    return left, right


class FSPool(nn.Module):
    """
        Simplified version of featurewise sort pooling, without the option of variable-size sets through masking. From:
        FSPool: Learning Set Representations with Featurewise Sort Pooling.
        Yan Zhang, Jonathon Hare, Adam Prügel-Bennett
        https://arxiv.org/abs/1906.02795
        https://github.com/Cyanogenoid/fspool

        Each channel is sorted independently across the set dimension, and the
        sorted values are combined with a learned weight per (rank, channel).
        Because sorting discards the original element order, the output is
        permutation-invariant.

        Args:
            in_channels: number of channels per set element.
            set_size: fixed number of elements per set (no masking support).

        Input ``x``: (batch, set_size, in_channels). Output: (batch, in_channels).
    """

    def __init__(self, in_channels, set_size):
        super().__init__()
        # One learned weight per sorted position and channel.
        initial = torch.zeros(set_size, in_channels)
        self.weight = nn.Parameter(initial)
        nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='linear')

    def forward(self, x):
        # Sort each channel independently along the set axis (ascending).
        ranked = torch.sort(x, dim=1).values
        # Weighted sum over sorted positions, separately for every channel.
        return torch.einsum('nlc,lc->nc', ranked, self.weight)


class FSEncoder(nn.Module):
    """Permutation-invariant set encoder: per-element MLP followed by FSPool.

    Args:
        input_channels: channels per input set element.
        dim: hidden width of the two-layer MLP.
        output_channels: channels per element after the MLP, and size of the
            pooled output vector.
        set_size: fixed number of elements per set (required by FSPool).

    Input ``x``: (batch, set_size, input_channels).
    Output: (batch, output_channels).
    """

    def __init__(self, input_channels, dim, output_channels, set_size):
        super().__init__()
        # nn.Linear acts on the last dim, so this MLP is applied to each element independently.
        layers = [
            nn.Linear(input_channels, dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, output_channels),
        ]
        self.mlp = nn.Sequential(*layers)
        self.pool = FSPool(output_channels, set_size)

    def forward(self, x):
        per_element = self.mlp(x)
        return self.pool(per_element)


class RNFSEncoder(FSEncoder):
    """Relation-network style FSEncoder that encodes all ordered pairs of elements.

    Every pair (i, j) of the input set is concatenated along the channel axis.
    The resulting set of ``set_size ** 2`` pairs is then fed to the regular
    FSEncoder pipeline (per-pair MLP + FSPool).

    Input ``x``: (batch, set_size, input_channels).
    Output: (batch, output_channels).
    """

    def __init__(self, input_channels, dim, output_channels, set_size):
        # Each pair carries both elements' channels, and there are set_size**2 pairs.
        pair_channels = 2 * input_channels
        num_pairs = set_size ** 2
        super().__init__(pair_channels, dim, output_channels, num_pairs)

    def forward(self, x):
        left, right = create_pairs(x)
        # (batch, n, n, 2c) -> (batch, n*n, 2c): treat the pairs as one flat set.
        pairs = torch.cat((left, right), dim=-1).flatten(start_dim=1, end_dim=2)
        return super().forward(pairs)


class ImageModel(nn.Module):
    """ResNet18-based image encoder that turns an image into a flat feature vector.

    The torchvision ResNet18 trunk (without its avgpool/fc head) downsamples by
    32. A BatchNorm plus a kernel-2, stride-2 conv then halves the resolution
    once more, leaving ``spatial_size x spatial_size`` cells with
    ``spatial_size = image_size // 32 // 2``. That gives 2x2 cells for 128-pixel
    images. The conv's channel count is chosen so that flattening the cells
    yields exactly ``latent`` features.

    NOTE(review): assumes square images whose side is a multiple of 32, so that
    ``image_size // 32`` matches the trunk's output resolution.

    Args:
        latent: size of the output feature vector. Must be divisible by
            ``spatial_size ** 2``.
        image_size: side length of the input images.

    Raises:
        ValueError: if ``image_size`` is too small to leave at least one cell
            (``image_size < 64``), or if ``latent`` is not divisible by the
            number of cells. Otherwise the output would be shorter than
            ``latent``.
    """

    def __init__(self, latent, image_size):
        super().__init__()
        resnet_output_dim = 512
        spatial_size = image_size // 32  # after resnet
        spatial_size = spatial_size // 2  # after strided conv
        # Validate before building any layers so parameter init order is unaffected.
        if spatial_size < 1:
            raise ValueError(
                f"image_size={image_size} is too small: need at least 64 pixels "
                f"to leave one spatial cell after the encoder"
            )
        num_cells = spatial_size ** 2
        if latent % num_cells != 0:
            raise ValueError(
                f"latent={latent} must be divisible by the number of spatial "
                f"cells ({spatial_size}x{spatial_size}={num_cells})"
            )
        resnet = torchvision.models.resnet18()
        # Drop the final avgpool and fc layers; keep the convolutional trunk.
        self.layers = nn.Sequential(*list(resnet.children())[:-2])
        self.end = nn.Sequential(
            nn.BatchNorm2d(resnet_output_dim),
            # Output: (n, latent // num_cells, spatial_size, spatial_size)
            nn.Conv2d(resnet_output_dim, latent // num_cells, 2, stride=2),
        )

    def forward(self, x):
        x = self.layers(x)
        x = self.end(x)
        # Flatten channels and spatial cells into a (n, latent) vector.
        return x.flatten(start_dim=1)


class DSPNModel(nn.Module):
    """Encode an input into a latent vector, then decode it into a set.

    The input encoder is ``ImageModel`` when ``image_input`` is set, otherwise
    a set encoder. The decoder is ``iDSPN`` when ``implicit`` is set, otherwise
    ``DSPN``. Both decoders receive their own freshly built set encoder.

    Args:
        d_in: channels per set element.
        d_hid: hidden width of the set encoders.
        d_latent: size of the latent vector.
        set_size: number of elements per set.
        lr: learning rate forwarded to the decoder.
        iters: iteration count forwarded to the decoder (``optim_iters`` for iDSPN).
        momentum: momentum of the SGD optimizer handed to iDSPN; Nesterov is
            enabled when ``momentum > 0``. Unused by DSPN.
        grad_clip: forwarded to iDSPN only.
        input_encoder: ``'rnfs'`` selects RNFSEncoder; any other value FSEncoder.
        decoder_encoder: same choice, for the decoder's set encoder.
        use_starting_set: forwarded to iDSPN and to MSEObjective's ``regularized``.
            Unused by DSPN.
        image_input: encode images with ImageModel instead of a set encoder.
        image_size: image side length, used only when ``image_input`` is True.
        implicit: use iDSPN instead of DSPN.
    """

    def __init__(self, d_in, d_hid, d_latent, set_size, lr=1, iters=20, momentum=0.9, grad_clip=None, input_encoder='rnfs', decoder_encoder='fs', use_starting_set=False, image_input=False, image_size=None, implicit=False):
        super().__init__()
        self.lr = lr
        self.iters = iters
        self.implicit = implicit

        # Input -> latent encoder (built first; construction order fixes parameter init).
        if not image_input:
            encoder_cls = self._set_encoder_cls(input_encoder)
            self.enc = encoder_cls(d_in, d_hid, d_latent, set_size)
        else:
            self.enc = ImageModel(d_latent, image_size)

        # Set encoder used inside the decoder.
        inner_cls = self._set_encoder_cls(decoder_encoder)
        set_encoder = inner_cls(d_in, d_hid, d_latent, set_size)

        if not self.implicit:
            self.dspn = DSPN(
                encoder=set_encoder,
                set_channels=d_in,
                max_set_size=set_size,
                channels=d_latent,
                iters=self.iters,
                lr=self.lr,
            )
            return

        def make_optimizer(params):
            # self.lr is read at call time, so later changes to it take effect.
            return torch.optim.SGD(params, lr=self.lr, momentum=momentum, nesterov=momentum > 0)

        self.dspn = iDSPN(
            objective=MSEObjective(set_encoder, regularized=use_starting_set),
            optim_f=make_optimizer,
            optim_iters=self.iters,
            set_channels=d_in,
            set_size=set_size,
            grad_clip=grad_clip,
            use_starting_set=use_starting_set,
        )

    @staticmethod
    def _set_encoder_cls(name):
        """Map an encoder name to its class ('rnfs' -> RNFSEncoder, else FSEncoder)."""
        return RNFSEncoder if name == 'rnfs' else FSEncoder

    def forward(self, x):
        latent = self.enc(x)
        prediction, set_grad = self.dspn(latent)
        return prediction, set_grad


class SlotAttentionModel(nn.Module):
    """Slot Attention baseline mapping an input set to an output set.

    Elements are projected to ``d_hid``, processed by SlotAttention with
    ``set_size`` slots, and projected back to ``d_in`` channels.

    Args:
        d_in: channels per set element (input and output).
        d_hid: working width inside SlotAttention.
        set_size: number of slots, passed to SlotAttention.
    """

    def __init__(self, d_in, d_hid, set_size):
        super().__init__()
        self.preproj = nn.Linear(d_in, d_hid)
        self.slot_attention = SlotAttention(set_size, d_hid, hidden_dim=d_hid, use_ln=True, iters=5)
        self.postproj = nn.Linear(d_hid, d_in)

    def forward(self, x):
        # The input is already a set, so Slot Attention only has to learn the
        # identity mapping; there is no latent bottleneck.
        hidden = self.preproj(x)
        slots = self.slot_attention(hidden, hidden)
        out = self.postproj(slots)
        # Zero placeholder so callers can unpack (prediction, set_grad) just as
        # they do for DSPNModel.
        placeholder_grad = torch.zeros(1, 1, 1, device=out.device)
        return out, placeholder_grad
